/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.test.checkpointing;

import org.apache.flink.api.common.functions.RichFilterFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichReduceFunction;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.common.state.CheckpointListener;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.co.RichCoFlatMapFunction;
import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction;
import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.util.Collector;

import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

/**
 * Integration test for the {@link CheckpointListener} interface. The test ensures that {@link
 * CheckpointListener#notifyCheckpointComplete(long)} is called for completed checkpoints, that it
 * is called at most once for any checkpoint id and that it is not called for a deliberately failed
 * checkpoint.
 *
 * <p>The topology tested here includes a number of {@link OneInputStreamOperator}s and a {@link
 * TwoInputStreamOperator}.
 *
 * <p>Note that as a result of doing the checks on the task level there is no way to verify that the
 * {@link CheckpointListener#notifyCheckpointComplete(long)} is called for every successfully
 * completed checkpoint.
 */
@SuppressWarnings("serial")
public class StreamCheckpointNotifierITCase extends AbstractTestBase {

    private static final Logger LOG = LoggerFactory.getLogger(StreamCheckpointNotifierITCase.class);

    private static final int PARALLELISM = 4;

    /**
     * Runs the following program and then verifies, per operator and per parallel subtask, that
     * (a) at least one checkpoint completion notification arrived, (b) the deliberately failed
     * checkpoint was never reported as completed, (c) at least one checkpoint completed after the
     * failure, and (d) no checkpoint id was notified more than once.
     *
     * <pre>
     *     [ (source)->(filter) ] -> [ (co-map) ] -> [ (map) ] -> [ (groupBy/reduce)->(sink) ]
     * </pre>
     */
    @Test
    public void testProgram() throws Exception {
        // let test failures propagate with their full stack trace instead of
        // printStackTrace() + fail(e.getMessage()), which loses the trace and
        // fails with a null message for exceptions without one
        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        assertEquals("test setup broken", PARALLELISM, env.getParallelism());

        env.enableCheckpointing(500);
        env.setRestartStrategy(RestartStrategies.fixedDelayRestart(Integer.MAX_VALUE, 0L));

        final int numElements = 10000;
        // sources keep running until this many tasks saw a post-failure notification
        final int numTaskTotal = PARALLELISM * 5;

        DataStream<Long> stream =
                env.addSource(new GeneratingSourceFunction(numElements, numTaskTotal));

        stream
                // -------------- first vertex, chained to the src ----------------
                .filter(new LongRichFilterFunction())

                // -------------- second vertex, applying the co-map ----------------
                .connect(stream)
                .flatMap(new LeftIdentityCoRichFlatMapFunction())

                // -------------- third vertex - the stateful one that also fails
                // ----------------
                .map(new IdentityMapFunction())
                .startNewChain()

                // -------------- fourth vertex - reducer and the sink ----------------
                .keyBy(0)
                .reduce(new OnceFailingReducer(numElements))
                .addSink(new DiscardingSink<Tuple1<Long>>());

        env.execute();

        // 0 means the reducer never threw, i.e. the failure was never induced
        final long failureCheckpointID = OnceFailingReducer.failureCheckpointID;
        assertNotEquals(0L, failureCheckpointID);

        List<List<Long>[]> allLists =
                Arrays.asList(
                        GeneratingSourceFunction.COMPLETED_CHECKPOINTS,
                        LongRichFilterFunction.COMPLETED_CHECKPOINTS,
                        LeftIdentityCoRichFlatMapFunction.COMPLETED_CHECKPOINTS,
                        IdentityMapFunction.COMPLETED_CHECKPOINTS,
                        OnceFailingReducer.COMPLETED_CHECKPOINTS);

        for (List<Long>[] parallelNotifications : allLists) {
            for (List<Long> notifications : parallelNotifications) {

                assertTrue("No checkpoint notification was received.", notifications.size() > 0);

                assertFalse(
                        "Failure checkpoint was marked as completed.",
                        notifications.contains(failureCheckpointID));

                // the last notified checkpoint must be one that completed after the failure;
                // assertNotEquals also avoids the boxed Long vs long '==' comparison
                assertNotEquals(
                        "No checkpoint received after failure.",
                        failureCheckpointID,
                        (long) notifications.get(notifications.size() - 1));

                // duplicate-free <=> same size as its set view; assertEquals reports both
                // sizes on failure, unlike assertTrue on a boolean comparison
                assertEquals(
                        "Checkpoint notification was received multiple times",
                        notifications.size(),
                        new HashSet<>(notifications).size());
            }
        }
    }

    /**
     * Allocates one empty, mutable checkpoint-id list per parallel subtask.
     *
     * @param parallelism number of subtask slots, i.e. the array length
     * @return an array of {@code parallelism} distinct, empty {@link ArrayList}s
     */
    static List<Long>[] createCheckpointLists(int parallelism) {
        @SuppressWarnings({"unchecked", "rawtypes"})
        List<Long>[] lists = new List[parallelism];
        Arrays.setAll(lists, i -> new ArrayList<Long>());
        return lists;
    }

    // --------------------------------------------------------------------------------------------
    //  Custom Functions
    // --------------------------------------------------------------------------------------------

    /**
     * Generates some Long values and as an implementation for the {@link CheckpointListener}
     * interface it stores all the checkpoint ids it has seen in a static list.
     *
     * <p>State: checkpoints the emission {@code index} via {@link ListCheckpointed}. Each subtask
     * starts at its own subtask index and advances by the total parallelism, so subtasks emit
     * disjoint index ranges. After emitting, the source idles until enough tasks across the whole
     * job have observed a post-failure checkpoint notification, then terminates.
     */
    private static class GeneratingSourceFunction extends RichSourceFunction<Long>
            implements ParallelSourceFunction<Long>, CheckpointListener, ListCheckpointed<Integer> {

        // completed checkpoint ids, one list per parallel subtask
        static final List<Long>[] COMPLETED_CHECKPOINTS = createCheckpointLists(PARALLELISM);

        // global count of tasks (across all operators in the job) that have seen at least
        // one checkpoint-complete notification after the induced failure
        static AtomicLong numPostFailureNotifications = new AtomicLong();

        // operator behaviour
        private final long numElements;

        // how many post-failure notifications must accumulate before run() may return
        private final int notificationsToWaitFor;

        private int index;
        private int step;

        // true once this task has counted itself in numPostFailureNotifications
        private volatile boolean notificationAlready;

        private volatile boolean isRunning = true;

        GeneratingSourceFunction(long numElements, int notificationsToWaitFor) {
            this.numElements = numElements;
            this.notificationsToWaitFor = notificationsToWaitFor;
        }

        @Override
        public void open(Configuration parameters) throws IOException {
            step = getRuntimeContext().getNumberOfParallelSubtasks();

            // if index has been restored, it is not 0 any more
            // NOTE(review): assumes a restored index is never 0; for subtask 0 a restored 0
            // and a fresh start coincide anyway, so the outcome is the same
            if (index == 0) {
                index = getRuntimeContext().getIndexOfThisSubtask();
            }
        }

        @Override
        public void run(SourceContext<Long> ctx) throws Exception {
            final Object lockingObject = ctx.getCheckpointLock();

            while (isRunning && index < numElements) {
                long result = index % 10;

                // advance state and emit atomically w.r.t. checkpoints, so a snapshot
                // never captures an index inconsistent with the emitted records
                synchronized (lockingObject) {
                    index += step;
                    ctx.collect(result);
                }
            }

            // if the program goes fast and no notifications come through, we
            // wait until all tasks had a chance to see a notification
            while (isRunning && numPostFailureNotifications.get() < notificationsToWaitFor) {
                Thread.sleep(50);
            }
        }

        @Override
        public void cancel() {
            isRunning = false;
        }

        @Override
        public List<Integer> snapshotState(long checkpointId, long timestamp) throws Exception {
            // single-element state: the next emission index
            return Collections.singletonList(this.index);
        }

        @Override
        public void restoreState(List<Integer> state) throws Exception {
            // exactly one element is expected because snapshotState() stores exactly one
            if (state.isEmpty() || state.size() > 1) {
                throw new RuntimeException(
                        "Test failed due to unexpected recovered state size " + state.size());
            }
            this.index = state.get(0);
        }

        @Override
        public void notifyCheckpointComplete(long checkpointId) {
            // record the ID of the completed checkpoint
            int partition = getRuntimeContext().getIndexOfThisSubtask();
            COMPLETED_CHECKPOINTS[partition].add(checkpointId);

            // if this is the first time we get a notification since the failure,
            // tell the source function
            if (OnceFailingReducer.hasFailed && !notificationAlready) {
                notificationAlready = true;
                GeneratingSourceFunction.numPostFailureNotifications.incrementAndGet();
            }
        }

        @Override
        public void notifyCheckpointAborted(long checkpointId) {}
    }

    /**
     * Identity transform that wraps each incoming Long into a {@link Tuple1}. Implements {@link
     * CheckpointListener} and records every completed checkpoint id it is notified about in a
     * static per-subtask list.
     */
    private static class IdentityMapFunction extends RichMapFunction<Long, Tuple1<Long>>
            implements CheckpointListener {

        static final List<Long>[] COMPLETED_CHECKPOINTS = createCheckpointLists(PARALLELISM);

        // whether this task has already counted itself after the induced failure
        private volatile boolean notificationAlready;

        @Override
        public Tuple1<Long> map(Long value) throws Exception {
            return new Tuple1<>(value);
        }

        @Override
        public void notifyCheckpointComplete(long checkpointId) {
            // remember the completed checkpoint id for this subtask
            final int subtask = getRuntimeContext().getIndexOfThisSubtask();
            COMPLETED_CHECKPOINTS[subtask].add(checkpointId);

            // first notification after the failure: signal the sources once
            if (!notificationAlready && OnceFailingReducer.hasFailed) {
                notificationAlready = true;
                GeneratingSourceFunction.numPostFailureNotifications.incrementAndGet();
            }
        }

        @Override
        public void notifyCheckpointAborted(long checkpointId) {}
    }

    /**
     * Filter on Long values that lets every record through (the sources only emit values in
     * [0, 9]). Implements {@link CheckpointListener} and records every completed checkpoint id it
     * is notified about in a static per-subtask list.
     */
    private static class LongRichFilterFunction extends RichFilterFunction<Long>
            implements CheckpointListener {

        static final List<Long>[] COMPLETED_CHECKPOINTS = createCheckpointLists(PARALLELISM);

        // whether this task has already counted itself after the induced failure
        private volatile boolean notificationAlready;

        @Override
        public boolean filter(Long value) {
            return value < 100L;
        }

        @Override
        public void notifyCheckpointComplete(long checkpointId) {
            // remember the completed checkpoint id for this subtask
            final int subtask = getRuntimeContext().getIndexOfThisSubtask();
            COMPLETED_CHECKPOINTS[subtask].add(checkpointId);

            // first notification after the failure: signal the sources once
            if (!notificationAlready && OnceFailingReducer.hasFailed) {
                notificationAlready = true;
                GeneratingSourceFunction.numPostFailureNotifications.incrementAndGet();
            }
        }

        @Override
        public void notifyCheckpointAborted(long checkpointId) {}
    }

    /**
     * CoFlatMap that forwards every element of the left (first) input unchanged and discards all
     * elements of the right (second) input. Implements {@link CheckpointListener} and records
     * every completed checkpoint id it is notified about in a static per-subtask list.
     */
    private static class LeftIdentityCoRichFlatMapFunction
            extends RichCoFlatMapFunction<Long, Long, Long> implements CheckpointListener {

        static final List<Long>[] COMPLETED_CHECKPOINTS = createCheckpointLists(PARALLELISM);

        // whether this task has already counted itself after the induced failure
        private volatile boolean notificationAlready;

        @Override
        public void flatMap1(Long value, Collector<Long> out) {
            // left input: pure identity
            out.collect(value);
        }

        @Override
        public void flatMap2(Long value, Collector<Long> out) {
            // right input: deliberately dropped
        }

        @Override
        public void notifyCheckpointComplete(long checkpointId) {
            // remember the completed checkpoint id for this subtask
            final int subtask = getRuntimeContext().getIndexOfThisSubtask();
            COMPLETED_CHECKPOINTS[subtask].add(checkpointId);

            // first notification after the failure: signal the sources once
            if (!notificationAlready && OnceFailingReducer.hasFailed) {
                notificationAlready = true;
                GeneratingSourceFunction.numPostFailureNotifications.incrementAndGet();
            }
        }

        @Override
        public void notifyCheckpointAborted(long checkpointId) {}
    }

    /** Reducer that causes one failure between seeing 40% to 70% of the records. */
    private static class OnceFailingReducer extends RichReduceFunction<Tuple1<Long>>
            implements ListCheckpointed<Long>, CheckpointListener {
        // job-wide flag: set exactly once, by subtask 0, when the failure is thrown
        static volatile boolean hasFailed = false;
        // id of the checkpoint whose snapshot was deliberately failed; 0 until the failure occurs
        static volatile long failureCheckpointID;

        // completed checkpoint ids, one list per parallel subtask
        static final List<Long>[] COMPLETED_CHECKPOINTS = createCheckpointLists(PARALLELISM);

        // record count at which subtask 0 fails its next snapshot (~50% of its share)
        private final long failurePos;

        private volatile long count;

        // true once this task has counted itself in numPostFailureNotifications
        private volatile boolean notificationAlready;

        OnceFailingReducer(long numElements) {
            this.failurePos = (long) (0.5 * numElements / PARALLELISM);
        }

        @Override
        public Tuple1<Long> reduce(Tuple1<Long> value1, Tuple1<Long> value2) {
            count++;
            if (count >= failurePos && getRuntimeContext().getIndexOfThisSubtask() == 0) {
                LOG.info(">>>>>>>>>>>>>>>>> Reached failing position <<<<<<<<<<<<<<<<<<<<<");
            }

            // mutates value1 in place; reduce functions may reuse their input objects here
            value1.f0 += value2.f0;
            return value1;
        }

        @Override
        public List<Long> snapshotState(long checkpointId, long timestamp) throws Exception {
            // fail exactly one snapshot, on subtask 0, once enough records were seen;
            // the checkpoint id is published so the test can assert it never completed
            if (!hasFailed
                    && count >= failurePos
                    && getRuntimeContext().getIndexOfThisSubtask() == 0) {
                LOG.info(">>>>>>>>>>>>>>>>> Throwing Exception <<<<<<<<<<<<<<<<<<<<<");
                hasFailed = true;
                failureCheckpointID = checkpointId;
                throw new Exception("Test Failure");
            }
            return Collections.singletonList(this.count);
        }

        @Override
        public void restoreState(List<Long> state) throws Exception {
            // exactly one element is expected because snapshotState() stores exactly one
            if (state.isEmpty() || state.size() > 1) {
                throw new RuntimeException(
                        "Test failed due to unexpected recovered state size " + state.size());
            }
            this.count = state.get(0);
        }

        @Override
        public void notifyCheckpointComplete(long checkpointId) {
            // record the ID of the completed checkpoint
            int partition = getRuntimeContext().getIndexOfThisSubtask();
            COMPLETED_CHECKPOINTS[partition].add(checkpointId);

            // if this is the first time we get a notification since the failure,
            // tell the source function
            if (OnceFailingReducer.hasFailed && !notificationAlready) {
                notificationAlready = true;
                GeneratingSourceFunction.numPostFailureNotifications.incrementAndGet();
            }
        }

        @Override
        public void notifyCheckpointAborted(long checkpointId) {}
    }
}
